{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3",
      "metadata": {
        "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3"
      },
      "source": [
        "# Fine Tuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "n9sehdR5Cv6A",
      "metadata": {
        "id": "n9sehdR5Cv6A"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install gensim\n",
        "%pip install --upgrade datasets==3.6.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b9bf9a6a",
      "metadata": {
        "id": "b9bf9a6a"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import math\n",
        "import random\n",
        "import json\n",
        "import pickle\n",
        "import re\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from tqdm import tqdm\n",
        "from pathlib import Path\n",
        "from openai import OpenAI\n",
        "from datetime import datetime\n",
        "from dotenv import load_dotenv\n",
        "import matplotlib.pyplot as plt\n",
        "from huggingface_hub import login\n",
        "from sklearn.svm import LinearSVR\n",
        "from gensim.models import Word2Vec\n",
        "from IPython.display import display\n",
        "from transformers import AutoTokenizer\n",
        "from gensim.utils import simple_preprocess\n",
        "from collections import Counter, defaultdict\n",
        "from sklearn.linear_model import LinearRegression\n",
        "from sklearn.ensemble import RandomForestRegressor\n",
        "from concurrent.futures import ProcessPoolExecutor\n",
        "from datasets import Dataset, DatasetDict, load_dataset\n",
        "from sklearn.metrics import mean_squared_error, r2_score\n",
        "from sklearn.feature_extraction.text import CountVectorizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "261d16fa",
      "metadata": {
        "id": "261d16fa"
      },
      "outputs": [],
      "source": [
        "# Read API credentials from a local .env file (its values override the process environment).\n",
        "load_dotenv(override=True)\n",
        "openai_key = os.environ.get(\"OPENAI_API_KEY\")\n",
        "hf_token = os.environ.get(\"HF_TOKEN\")\n",
        "\n",
        "# Never echo the token itself: notebook outputs are saved and shared.\n",
        "print(f\"HF_TOKEN {'found' if hf_token else 'not set'}\")\n",
        "\n",
        "if hf_token:\n",
        "    print(\"Logging in to Hugging Face...\")\n",
        "    login(hf_token, add_to_git_credential=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2cdfe762-3200-4459-981e-0ded7c14b4de",
      "metadata": {
        "id": "2cdfe762-3200-4459-981e-0ded7c14b4de"
      },
      "outputs": [],
      "source": [
        "# ANSI escape sequences used to colour the per-item lines printed by the tester.\n",
        "GREEN, YELLOW, RED, RESET = \"\\033[92m\", \"\\033[93m\", \"\\033[91m\", \"\\033[0m\"\n",
        "\n",
        "# Tester colour label -> escape sequence (\"orange\" is rendered with yellow).\n",
        "COLOR_MAP = dict(red=RED, orange=YELLOW, green=GREEN)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0832e74b-2779-4822-8e6c-4361ec165c7f",
      "metadata": {
        "id": "0832e74b-2779-4822-8e6c-4361ec165c7f"
      },
      "outputs": [],
      "source": [
        "# Tokenizer source; the token limits below are measured in this model's tokens.\n",
        "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
        "\n",
        "MIN_CHARS = 300                 # raw text must be longer than this to be tokenized\n",
        "MIN_TOKENS = 150                # tokenized text must be longer than this to be kept\n",
        "MAX_TOKENS = 160                # token budget for title + details in a prompt\n",
        "CEILING_CHARS = MAX_TOKENS * 7  # character cap applied before tokenizing\n",
        "\n",
        "class Item:\n",
        "    \"\"\"A catalog product turned into a price-prediction prompt.\n",
        "\n",
        "    `include` becomes True only when the product text tokenizes to more than\n",
        "    MIN_TOKENS tokens; `prompt` then holds the training text (product text cut to\n",
        "    MAX_TOKENS tokens) ending in \"Price is $<rounded price>.00\".\n",
        "    \"\"\"\n",
        "\n",
        "    # Created once, when the class is defined (may download from the Hugging Face Hub).\n",
        "    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
        "    PREFIX = \"Price is $\"\n",
        "    QUESTION = \"How much does this cost to the nearest dollar?\"\n",
        "    # Boilerplate substrings deleted from the raw `details` text.\n",
        "    REMOVALS = ['\"Batteries Included?\": \"No\"', '\"Batteries Included?\": \"Yes\"', '\"Batteries Required?\": \"No\"', '\"Batteries Required?\": \"Yes\"', \"By Manufacturer\", \"Item\", \"Date First\", \"Package\", \":\", \"Number of\", \"Best Sellers\", \"Number\", \"Product \"]\n",
        "\n",
        "    def __init__(self, data, price):\n",
        "        \"\"\"Build an item from a raw catalog record `data` and its numeric `price`.\"\"\"\n",
        "        self.title = data[\"title\"]\n",
        "        self.price = price\n",
        "        self.category = data.get(\"category\", \"Unknown\")\n",
        "        self.token_count = 0\n",
        "        self.details = None\n",
        "        self.prompt = None\n",
        "        self.include = False\n",
        "        self.parse(data)\n",
        "\n",
        "    def scrub_details(self):\n",
        "        \"\"\"Return `self.details` with every REMOVALS substring deleted.\"\"\"\n",
        "        details = self.details\n",
        "\n",
        "        for remove in self.REMOVALS:\n",
        "            details = details.replace(remove, \"\")\n",
        "\n",
        "        return details\n",
        "\n",
        "    def scrub(self, text):\n",
        "        \"\"\"Collapse runs of colons, brackets, braces, double quotes and whitespace into\n",
        "        single spaces, tidy commas, and drop words of 7+ characters that contain a\n",
        "        digit (presumably part/model numbers).\n",
        "        \"\"\"\n",
        "        text = re.sub(r'[:\\[\\]\"{}【】\\s]+', ' ', text).strip()\n",
        "        text = text.replace(\" ,\", \",\").replace(\",,,\",\",\").replace(\",,\",\",\")\n",
        "        words = text.split(\" \")\n",
        "        select = [word for word in words if len(word) < 7 or not any(char.isdigit() for char in word)]\n",
        "        return \" \".join(select)\n",
        "\n",
        "    def parse(self, data):\n",
        "        \"\"\"Assemble description, features and scrubbed details, then build the prompt.\n",
        "\n",
        "        Items whose text is not longer than MIN_CHARS characters, or whose scrubbed\n",
        "        text is not longer than MIN_TOKENS tokens, keep include=False.\n",
        "        \"\"\"\n",
        "        contents = '\\n'.join(data.get(\"description\", []))\n",
        "\n",
        "        if contents:\n",
        "            contents += '\\n'\n",
        "\n",
        "        features = '\\n'.join(data.get(\"features\", []))\n",
        "        if features:\n",
        "            contents += features + '\\n'\n",
        "\n",
        "        self.details = data.get(\"details\")\n",
        "        if self.details:\n",
        "            contents += self.scrub_details() + '\\n'\n",
        "\n",
        "        if len(contents) > MIN_CHARS:\n",
        "            # Cheap character cut first, bounding the text handed to the tokenizer.\n",
        "            contents = contents[:CEILING_CHARS]\n",
        "            text = f\"{self.scrub(self.title)}\\n{self.scrub(contents)}\"\n",
        "            tokens = self.tokenizer.encode(text, add_special_tokens=False)\n",
        "\n",
        "            if len(tokens) > MIN_TOKENS:\n",
        "                tokens = tokens[:MAX_TOKENS]\n",
        "                text = self.tokenizer.decode(tokens)\n",
        "                self.make_prompt(text)\n",
        "                self.include = True\n",
        "\n",
        "    def make_prompt(self, text):\n",
        "        \"\"\"Store the full training prompt (question, text, answer) and its token count.\"\"\"\n",
        "        self.prompt = f\"{self.QUESTION}\\n\\n{text}\\n\\n\"\n",
        "        self.prompt += f\"{self.PREFIX}{str(round(self.price))}.00\"\n",
        "        self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))\n",
        "\n",
        "    def test_prompt(self):\n",
        "        \"\"\"Return the prompt with the answer removed, ending in PREFIX.\n",
        "\n",
        "        Splits on the *last* PREFIX: the answer is always appended at the end, while\n",
        "        product text could itself contain \"Price is $\" and must not truncate the prompt.\n",
        "        \"\"\"\n",
        "        head, sep, _ = self.prompt.rpartition(self.PREFIX)\n",
        "        return (head if sep else self.prompt) + self.PREFIX\n",
        "\n",
        "    def __repr__(self):\n",
        "        return f\"<{self.title} = ${self.price}>\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "aa478d70",
      "metadata": {
        "id": "aa478d70"
      },
      "outputs": [],
      "source": [
        "MIN_PRICE = 0.5\n",
        "CHUNK_SIZE = 1000\n",
        "MAX_PRICE = 999.49\n",
        "\n",
        "class ItemLoader:\n",
        "    \"\"\"Load one Amazon-Reviews-2023 metadata category and convert its rows to Items.\"\"\"\n",
        "\n",
        "    def __init__(self, name):\n",
        "        \"\"\"`name` selects the dataset config `raw_meta_{name}` and becomes each Item's category.\"\"\"\n",
        "        self.name = name\n",
        "        self.dataset = None\n",
        "\n",
        "    def from_datapoint(self, datapoint):\n",
        "        \"\"\"Return an included Item for `datapoint`, or None when it is unusable.\n",
        "\n",
        "        Unusable means: no price, a non-numeric price, a price outside\n",
        "        [MIN_PRICE, MAX_PRICE], or too little text to build a prompt.\n",
        "        \"\"\"\n",
        "        try:\n",
        "            price_str = datapoint.get(\"price\")\n",
        "            if price_str:\n",
        "                price = float(price_str)\n",
        "                if MIN_PRICE <= price <= MAX_PRICE:\n",
        "                    item = Item(datapoint, price)\n",
        "                    if item.include:\n",
        "                        return item\n",
        "        except ValueError:\n",
        "            return None\n",
        "        return None\n",
        "\n",
        "    def from_chunk(self, chunk):\n",
        "        \"\"\"Convert every datapoint in `chunk`, keeping only successful Items.\"\"\"\n",
        "        batch = []\n",
        "        for datapoint in chunk:\n",
        "            item = self.from_datapoint(datapoint)\n",
        "\n",
        "            if item:\n",
        "                batch.append(item)\n",
        "\n",
        "        return batch\n",
        "\n",
        "    def _chunk_count(self):\n",
        "        \"\"\"Number of chunks chunk_generator() yields: ceil(len(dataset) / CHUNK_SIZE).\"\"\"\n",
        "        return (len(self.dataset) + CHUNK_SIZE - 1) // CHUNK_SIZE\n",
        "\n",
        "    def chunk_generator(self):\n",
        "        \"\"\"Yield consecutive CHUNK_SIZE-row selections (the last one may be shorter).\"\"\"\n",
        "        size = len(self.dataset)\n",
        "        for start in range(0, size, CHUNK_SIZE):\n",
        "            yield self.dataset.select(range(start, min(start + CHUNK_SIZE, size)))\n",
        "\n",
        "    def load_in_parallel(self, workers):\n",
        "        \"\"\"Convert all chunks across `workers` processes and tag each Item with this category.\"\"\"\n",
        "        results = []\n",
        "        # Ceiling division: an exact multiple of CHUNK_SIZE must not inflate the progress total.\n",
        "        chunk_count = self._chunk_count()\n",
        "\n",
        "        with ProcessPoolExecutor(max_workers=workers) as pool:\n",
        "            for batch in tqdm(pool.map(self.from_chunk, self.chunk_generator()), total=chunk_count):\n",
        "                results.extend(batch)\n",
        "\n",
        "        for result in results:\n",
        "            result.category = self.name\n",
        "\n",
        "        return results\n",
        "\n",
        "    def load(self, workers=8):\n",
        "        \"\"\"Load the `raw_meta_{name}` split and return the Items built from it.\"\"\"\n",
        "        self.dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_{self.name}\", split=\"full\", trust_remote_code=True)\n",
        "        start = datetime.now()\n",
        "        print(f\"Loading {self.dataset}\")\n",
        "        results = self.load_in_parallel(workers)\n",
        "        duration = (datetime.now() - start).total_seconds() / 60\n",
        "        print(f\"Completed {self.name} with {len(results):,} items in {duration:.1f} mins\")\n",
        "        return results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6XywRUiUro69",
      "metadata": {
        "id": "6XywRUiUro69"
      },
      "outputs": [],
      "source": [
        "class Tester:\n",
        "    \"\"\"Score a price predictor on the first `size` items of `data` and chart the result.\n",
        "\n",
        "    Each item gets a coloured console line; the final chart title reports the mean\n",
        "    absolute error, RMSLE and the share of green (\"hit\") guesses.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, predictor, data, title=None, size=250):\n",
        "        self.predictor = predictor\n",
        "        self.data = data\n",
        "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
        "        # Never score more items than exist: run() would otherwise raise IndexError\n",
        "        # on short test sets and report() would divide by the wrong count.\n",
        "        self.size = min(size, len(data))\n",
        "        self.guesses = []\n",
        "        self.truths = []\n",
        "        self.errors = []\n",
        "        self.sles = []\n",
        "        self.colors = []\n",
        "\n",
        "    def color_for(self, error, truth):\n",
        "        \"\"\"Green if within $40 or 20%, orange if within $80 or 40%, otherwise red.\"\"\"\n",
        "        if error < 40 or error / truth < 0.2:\n",
        "            return \"green\"\n",
        "\n",
        "        if error < 80 or error / truth < 0.4:\n",
        "            return \"orange\"\n",
        "\n",
        "        return \"red\"\n",
        "\n",
        "    @staticmethod\n",
        "    def squared_log_error(guess, truth):\n",
        "        \"\"\"Return (log(truth + 1) - log(guess + 1)) ** 2, treating negative guesses as 0.\n",
        "\n",
        "        Regression baselines can predict negative prices; without the clamp a guess\n",
        "        of -1 or less makes math.log raise a domain error.\n",
        "        \"\"\"\n",
        "        log_error = math.log(truth + 1) - math.log(max(guess, 0) + 1)\n",
        "        return log_error ** 2\n",
        "\n",
        "    def run_datapoint(self, index):\n",
        "        \"\"\"Predict item `index`, record its metrics and print a coloured summary line.\"\"\"\n",
        "        datapoint = self.data[index]\n",
        "        guess = self.predictor(datapoint)\n",
        "        truth = datapoint.price\n",
        "        error = abs(guess - truth)\n",
        "        sle = self.squared_log_error(guess, truth)\n",
        "        color = self.color_for(error, truth)\n",
        "        name = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40] + \"...\"\n",
        "        self.guesses.append(guess)\n",
        "        self.truths.append(truth)\n",
        "        self.errors.append(error)\n",
        "        self.sles.append(sle)\n",
        "        self.colors.append(color)\n",
        "        print(f\"{COLOR_MAP[color]}{index + 1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {name}{RESET}\")\n",
        "\n",
        "    def chart(self, title):\n",
        "        \"\"\"Scatter guesses against truths, with a y = x reference line.\"\"\"\n",
        "        plt.figure(figsize=(12, 8))\n",
        "        max_val = max(max(self.truths), max(self.guesses))\n",
        "        plt.plot([0, max_val], [0, max_val], color=\"deepskyblue\", lw=2, alpha=0.6)\n",
        "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
        "        plt.xlabel(\"Ground Truth\")\n",
        "        plt.ylabel(\"Model Estimate\")\n",
        "        plt.xlim(0, max_val)\n",
        "        plt.ylim(0, max_val)\n",
        "        plt.title(title)\n",
        "        plt.show()\n",
        "\n",
        "    def report(self):\n",
        "        \"\"\"Summarise the recorded metrics in the chart title and draw the chart.\"\"\"\n",
        "        count = len(self.errors)\n",
        "        if count == 0:\n",
        "            print(f\"{self.title}: no items to evaluate\")\n",
        "            return\n",
        "        average_error = sum(self.errors) / count\n",
        "        rmsle = math.sqrt(sum(self.sles) / count)\n",
        "        hits = sum(1 for color in self.colors if color == \"green\")\n",
        "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits / count * 100:.1f}%\"\n",
        "        self.chart(title)\n",
        "\n",
        "    def run(self):\n",
        "        \"\"\"Score the first `size` items, then report.\"\"\"\n",
        "        for index in range(self.size):\n",
        "            self.run_datapoint(index)\n",
        "\n",
        "        self.report()\n",
        "\n",
        "    @classmethod\n",
        "    def test(cls, function, data):\n",
        "        \"\"\"Convenience wrapper: build a Tester with default settings and run it.\"\"\"\n",
        "        cls(function, data).run()\n",
        "\n",
        "def get_price(s):\n",
        "    \"\"\"Extract the first number in `s` (ignoring \"$\" and \",\"), or 0.0 if there is none.\"\"\"\n",
        "    cleaned = s.replace(\"$\", \"\").replace(\",\", \"\")\n",
        "    found = re.search(r\"[-+]?\\d*\\.?\\d+\", cleaned)\n",
        "    if found is None:\n",
        "        return 0.0\n",
        "    return float(found.group())"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9856f570",
      "metadata": {
        "id": "9856f570"
      },
      "source": [
        "## Data"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3XTxVhq0xC8Z",
      "metadata": {
        "id": "3XTxVhq0xC8Z"
      },
      "source": [
        "### Load Catalogs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22",
      "metadata": {
        "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22"
      },
      "outputs": [],
      "source": [
        "# Categories to load; uncomment entries to include more of the catalog.\n",
        "catalog_labels = [\n",
        "    \"All_Beauty\",\n",
        "    # \"Automotive\",\n",
        "    # \"Electronics\",\n",
        "    # \"Office_Products\",\n",
        "    # \"Tools_and_Home_Improvement\",\n",
        "    # \"Cell_Phones_and_Accessories\",\n",
        "    # \"Toys_and_Games\",\n",
        "    \"Appliances\",\n",
        "    \"Musical_Instruments\",\n",
        "    \"Software\",\n",
        "    \"Handmade_Products\",\n",
        "]\n",
        "\n",
        "curated_pool = []\n",
        "for label in catalog_labels:\n",
        "    print(f\"Loading {label}\")\n",
        "    curated_pool += ItemLoader(label).load()\n",
        "\n",
        "print(f\"Total curated items: {len(curated_pool):,}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d",
      "metadata": {
        "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d"
      },
      "outputs": [],
      "source": [
        "# Price / prompt-length statistics and per-category counts for every parsed item.\n",
        "price_series = [entry.price for entry in curated_pool]\n",
        "token_series = [entry.token_count for entry in curated_pool]\n",
        "summary_frame = pd.DataFrame({\"price\": price_series, \"tokens\": token_series})\n",
        "category_tally = Counter(entry.category for entry in curated_pool)\n",
        "\n",
        "display(summary_frame.describe())\n",
        "display(\n",
        "    pd.DataFrame.from_dict(category_tally, orient=\"index\", columns=[\"count\"])\n",
        "    .sort_values(\"count\", ascending=False)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b",
      "metadata": {
        "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b"
      },
      "outputs": [],
      "source": [
        "# Bucket items by whole-dollar price ($1-$999) so each price point can be balanced separately.\n",
        "price_slots = defaultdict(list)\n",
        "for item in curated_pool:\n",
        "    dollars = round(item.price)\n",
        "    if not 1 <= dollars <= 999:\n",
        "        continue\n",
        "    price_slots[dollars].append(item)\n",
        "\n",
        "slot_counts = {dollars: len(bucket) for dollars, bucket in price_slots.items()}\n",
        "print(f\"Slots populated: {len(slot_counts)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585",
      "metadata": {
        "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585"
      },
      "outputs": [],
      "source": [
        "# Flatten the price distribution: keep every item at $240+ or in a slot of at most\n",
        "# 1,200 items; crowded cheaper slots are down-sampled to exactly 1,200.\n",
        "random.seed(123)\n",
        "np.random.seed(123)\n",
        "balanced_bundle = []\n",
        "\n",
        "for price in range(1, 1000):\n",
        "    bucket = price_slots.get(price, [])\n",
        "    if price >= 240 or len(bucket) <= 1200:\n",
        "        balanced_bundle.extend(bucket)\n",
        "        continue\n",
        "\n",
        "    # Weighted draw without replacement: Automotive items get 1/5 the weight of the others.\n",
        "    weights = np.array([1.0 if item.category == \"Automotive\" else 5.0 for item in bucket])\n",
        "    weights /= weights.sum()\n",
        "    chosen = np.random.choice(len(bucket), size=1200, replace=False, p=weights)\n",
        "    balanced_bundle.extend(bucket[idx] for idx in chosen)\n",
        "\n",
        "print(f\"Balanced bundle size: {len(balanced_bundle):,}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a506f42c-81c0-4198-bc0b-1e0653620be8",
      "metadata": {
        "id": "a506f42c-81c0-4198-bc0b-1e0653620be8"
      },
      "outputs": [],
      "source": [
        "# Price statistics and per-category counts after balancing.\n",
        "bundle_prices = [entry.price for entry in balanced_bundle]\n",
        "bundle_tokens = [entry.token_count for entry in balanced_bundle]\n",
        "bundle_categories = Counter(entry.category for entry in balanced_bundle)\n",
        "\n",
        "display(pd.Series(bundle_prices).describe())\n",
        "display(\n",
        "    pd.DataFrame.from_dict(bundle_categories, orient=\"index\", columns=[\"count\"])\n",
        "    .sort_values(\"count\", ascending=False)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88",
      "metadata": {
        "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88"
      },
      "outputs": [],
      "source": [
        "# Distributions of the balanced bundle. The price bins run up to 1000 so that items\n",
        "# priced above $990 (up to MAX_PRICE) land in the last bin instead of being dropped.\n",
        "fig, (price_ax, token_ax) = plt.subplots(2, 1, figsize=(12, 10))\n",
        "\n",
        "price_ax.hist(bundle_prices, bins=range(0, 1010, 10), color=\"midnightblue\", rwidth=0.8)\n",
        "price_ax.set(title=\"Price distribution\", xlabel=\"Price\", ylabel=\"Count\")\n",
        "\n",
        "token_ax.hist(bundle_tokens, bins=range(0, 300, 10), color=\"forestgreen\", rwidth=0.8)\n",
        "token_ax.set(title=\"Prompt token count distribution\", xlabel=\"Tokens\", ylabel=\"Count\")\n",
        "\n",
        "fig.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "42ee0099-0d2a-4331-a01c-3462363a6987",
      "metadata": {
        "id": "42ee0099-0d2a-4331-a01c-3462363a6987"
      },
      "outputs": [],
      "source": [
        "# Reproducible shuffle, then a test slice of ~5% (capped at 2,000 items) placed after\n",
        "# a training slice of at most 400,000 items.\n",
        "random.seed(123)\n",
        "random.shuffle(balanced_bundle)\n",
        "\n",
        "test_target = min(2000, max(1, len(balanced_bundle) // 20))\n",
        "train_target = min(400_000, len(balanced_bundle) - test_target)\n",
        "train_items, test_items = (\n",
        "    balanced_bundle[:train_target],\n",
        "    balanced_bundle[train_target:train_target + test_target],\n",
        ")\n",
        "\n",
        "print(f\"Training set: {len(train_items):,}\")\n",
        "print(f\"Test set: {len(test_items):,}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b",
      "metadata": {
        "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b"
      },
      "outputs": [],
      "source": [
        "# Training texts end with the answer; test texts stop right after \"Price is $\".\n",
        "train_prompts, train_prices = [it.prompt for it in train_items], [it.price for it in train_items]\n",
        "test_prompts, test_prices = [it.test_prompt() for it in test_items], [it.price for it in test_items]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16",
      "metadata": {
        "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16"
      },
      "outputs": [],
      "source": [
        "# Package the prompt/price pairs as a Hugging Face DatasetDict with train and test splits.\n",
        "train_dataset = Dataset.from_dict(dict(text=train_prompts, price=train_prices))\n",
        "test_dataset = Dataset.from_dict(dict(text=test_prompts, price=test_prices))\n",
        "pricing_dataset = DatasetDict({\"train\": train_dataset, \"test\": test_dataset})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8",
      "metadata": {
        "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8"
      },
      "source": [
        "### Persist"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c",
      "metadata": {
        "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c"
      },
      "outputs": [],
      "source": [
        "# Pickle the Item objects. Unpickling needs the Item class defined, and pickle files\n",
        "# should only ever be loaded from a trusted source.\n",
        "storage_dir = Path(\"data\")\n",
        "storage_dir.mkdir(exist_ok=True)\n",
        "\n",
        "for pickle_name, pickle_items in ((\"balanced_train.pkl\", train_items), (\"balanced_test.pkl\", test_items)):\n",
        "    with (storage_dir / pickle_name).open(\"wb\") as handle:\n",
        "        pickle.dump(pickle_items, handle)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753",
      "metadata": {
        "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753"
      },
      "outputs": [],
      "source": [
        "# Save each split's text/price columns as Parquet in the same data directory.\n",
        "for split_name in (\"train\", \"test\"):\n",
        "    pricing_dataset[split_name].to_parquet(storage_dir / f\"balanced_{split_name}.parquet\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013",
      "metadata": {
        "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013"
      },
      "source": [
        "## Baselines"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "qX0c_prppnyZ",
      "metadata": {
        "id": "qX0c_prppnyZ"
      },
      "source": [
        "### Stochastic Anchor"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b",
      "metadata": {
        "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b"
      },
      "outputs": [],
      "source": [
        "def stochastic_anchor(item):\n",
        "    \"\"\"Baseline: ignore the item and guess a uniformly random whole-dollar price from $1 to $999.\"\"\"\n",
        "    return random.randint(1, 999)\n",
        "\n",
        "random.seed(123)  # fixed seed so the random baseline's guesses are reproducible\n",
        "Tester(stochastic_anchor, test_items[:250]).run()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "O0xVXRXkp9sQ",
      "metadata": {
        "id": "O0xVXRXkp9sQ"
      },
      "source": [
        "### Global Mean"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6a932b0e-ba6e-45d2-8436-b740c3681272",
      "metadata": {
        "id": "6a932b0e-ba6e-45d2-8436-b740c3681272"
      },
      "outputs": [],
      "source": [
        "# Baseline: always predict the mean training price.\n",
        "train_price_values = [entry.price for entry in train_items]\n",
        "global_mean_price = sum(train_price_values) / len(train_price_values)\n",
        "\n",
        "def global_mean_estimator(item):\n",
        "    \"\"\"Ignore the item and return the mean training price.\"\"\"\n",
        "    return global_mean_price\n",
        "\n",
        "Tester(global_mean_estimator, test_items[:250]).run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4",
      "metadata": {
        "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4"
      },
      "outputs": [],
      "source": [
        "def parse_features(raw):\n",
        "    \"\"\"Decode an item's JSON `details` string into a dict.\n",
        "\n",
        "    Returns {} when `raw` is empty, is not valid JSON text, or decodes to something\n",
        "    other than a JSON object, so callers can always use `.get()` on the result.\n",
        "    \"\"\"\n",
        "    if not raw:\n",
        "        return {}\n",
        "    try:\n",
        "        parsed = json.loads(raw)\n",
        "    except (json.JSONDecodeError, TypeError):\n",
        "        # TypeError: `raw` was not str/bytes (e.g. an already-decoded value).\n",
        "        return {}\n",
        "    return parsed if isinstance(parsed, dict) else {}\n",
        "\n",
        "# Attach the decoded details to every train and test item as `item.features`.\n",
        "for item in [*train_items, *test_items]:\n",
        "    item.features = parse_features(item.details)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "44537051-7b4e-4b8c-95a7-a989ea51e517",
      "metadata": {
        "id": "44537051-7b4e-4b8c-95a7-a989ea51e517"
      },
      "source": [
        "### Feature Engineering"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec",
      "metadata": {
        "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec"
      },
      "outputs": [],
      "source": [
        "def infer_weight(item):\n",
        "    \"\"\"Return the item's \"Item Weight\" detail converted to pounds, or None.\n",
        "\n",
        "    Parses \"<amount> <unit>\" strings such as \"12 Ounces\" (or \"5 hundredths pounds\").\n",
        "    Missing, malformed or unknown-unit values yield None instead of raising.\n",
        "    \"\"\"\n",
        "    payload = item.features.get(\"Item Weight\")\n",
        "    if not payload or not isinstance(payload, str):\n",
        "        return None\n",
        "\n",
        "    parts = payload.split()\n",
        "    if len(parts) < 2:\n",
        "        return None\n",
        "    try:\n",
        "        amount = float(parts[0])\n",
        "    except ValueError:\n",
        "        return None\n",
        "    unit = parts[1].lower()\n",
        "\n",
        "    # How many of each unit make up one pound.\n",
        "    per_pound = {\n",
        "        \"pounds\": 1, \"pound\": 1,\n",
        "        \"ounces\": 16, \"ounce\": 16,\n",
        "        \"grams\": 453.592, \"gram\": 453.592,\n",
        "        \"milligrams\": 453592, \"milligram\": 453592,\n",
        "        \"kilograms\": 0.453592, \"kilogram\": 0.453592,\n",
        "    }\n",
        "    if unit in per_pound:\n",
        "        return amount / per_pound[unit]\n",
        "\n",
        "    if unit == \"hundredths\" and len(parts) > 2 and parts[2].lower() == \"pounds\":\n",
        "        return amount / 100\n",
        "\n",
        "    return None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4d7b6f35-890c-4227-8990-6b62694a332d",
      "metadata": {
        "id": "4d7b6f35-890c-4227-8990-6b62694a332d"
      },
      "outputs": [],
      "source": [
        "def infer_rank(item):\n",
        "    \"\"\"Return the mean of the numeric \"Best Sellers Rank\" values, or None.\n",
        "\n",
        "    The detail is expected to be a dict (presumably category -> rank); any other\n",
        "    shape yields None, and non-numeric entries are skipped instead of crashing sum().\n",
        "    \"\"\"\n",
        "    payload = item.features.get(\"Best Sellers Rank\")\n",
        "    if not isinstance(payload, dict):\n",
        "        return None\n",
        "\n",
        "    values = [v for v in payload.values() if isinstance(v, (int, float)) and not isinstance(v, bool)]\n",
        "    if not values:\n",
        "        return None\n",
        "\n",
        "    return sum(values) / len(values)\n",
        "\n",
        "# Lower-cased brand names for which the `top_brand` feature is 1.\n",
        "top_brands = {\"nvidia\", \"hp\", \"dell\", \"lenovo\", \"samsung\", \"asus\", \"sony\", \"canon\", \"apple\", \"intel\"}\n",
        "\n",
        "def is_top_brand(item):\n",
        "    \"\"\"Return 1 if the item's \"Brand\" detail (case-insensitive, trimmed) is in top_brands, else 0.\"\"\"\n",
        "    brand = item.features.get(\"Brand\")\n",
        "    if not isinstance(brand, str):\n",
        "        return 0\n",
        "    return 1 if brand.strip().lower() in top_brands else 0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8",
      "metadata": {
        "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8"
      },
      "outputs": [],
      "source": [
        "# Training-set means, used to impute weight/rank when an item lacks them.\n",
        "train_weights = [w for w in map(infer_weight, train_items) if w is not None]\n",
        "train_ranks = [r for r in map(infer_rank, train_items) if r is not None]\n",
        "average_weight = sum(train_weights) / len(train_weights) if train_weights else 1.0\n",
        "average_rank = sum(train_ranks) / len(train_ranks) if train_ranks else 1_000_000.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42",
      "metadata": {
        "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42"
      },
      "outputs": [],
      "source": [
        "def build_features(item):\n",
        "    \"\"\"Hand-crafted numeric features for `item`; missing weight/rank fall back to training means.\"\"\"\n",
        "    weight, rank = infer_weight(item), infer_rank(item)\n",
        "    row = {}\n",
        "    row[\"weight\"] = average_weight if weight is None else weight\n",
        "    row[\"rank\"] = average_rank if rank is None else rank\n",
        "    row[\"text_length\"] = len(item.test_prompt())\n",
        "    row[\"top_brand\"] = is_top_brand(item)\n",
        "    return row"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "189e959c-d70c-4509-bff6-1cbd8e8db637",
      "metadata": {
        "id": "189e959c-d70c-4509-bff6-1cbd8e8db637"
      },
      "outputs": [],
      "source": [
        "# Feature tables with the target price: all training items and the 250 evaluation items.\n",
        "train_frame = pd.DataFrame(list(map(build_features, train_items))).assign(price=[it.price for it in train_items])\n",
        "test_frame = pd.DataFrame(list(map(build_features, test_items[:250]))).assign(price=[it.price for it in test_items[:250]])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2",
      "metadata": {
        "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\"weight\", \"rank\", \"text_length\", \"top_brand\"]\n",
        "X_train = train_frame[feature_columns]\n",
        "y_train = train_frame[\"price\"]\n",
        "X_test = test_frame[feature_columns]\n",
        "y_test = test_frame[\"price\"]\n",
        "linear_model = LinearRegression()\n",
        "linear_model.fit(X_train, y_train)\n",
        "\n",
        "def linear_baseline(item):\n",
        "    \"\"\"Predict a price from the hand-crafted features with the fitted linear model.\n",
        "\n",
        "    Columns are selected explicitly so they always match the training order, and\n",
        "    the result is floored at 0: a price cannot be negative, and a guess of -1 or\n",
        "    less would make the tester's log-error computation raise.\n",
        "    \"\"\"\n",
        "    features = pd.DataFrame([build_features(item)])[feature_columns]\n",
        "    return max(float(linear_model.predict(features)[0]), 0.0)\n",
        "\n",
        "Tester.test(linear_baseline, test_items[:250])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ga-f4JK7sPU2",
      "metadata": {
        "id": "ga-f4JK7sPU2"
      },
      "source": [
        "### NLP Baselines"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "de958a51-69ba-420c-84b7-d32765898fd2",
      "metadata": {
        "id": "de958a51-69ba-420c-84b7-d32765898fd2"
      },
      "outputs": [],
      "source": [
        "# Bag-of-words baseline: counts of the 1,000 most frequent non-stop-words, fed to linear regression.\n",
        "document_texts = [entry.test_prompt() for entry in train_items]\n",
        "price_targets = np.array([entry.price for entry in train_items])\n",
        "\n",
        "vectorizer = CountVectorizer(max_features=1000, stop_words=\"english\")\n",
        "X_matrix = vectorizer.fit_transform(document_texts)\n",
        "bow_model = LinearRegression()\n",
        "bow_model.fit(X_matrix, price_targets)\n",
        "\n",
        "def bow_predictor(item):\n",
        "    \"\"\"Predict from the item's word counts, never returning a negative price.\"\"\"\n",
        "    counts = vectorizer.transform([item.test_prompt()])\n",
        "    prediction = float(bow_model.predict(counts)[0])\n",
        "    return max(prediction, 0)\n",
        "\n",
        "Tester(bow_predictor, test_items[:250]).run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QFDAoNnoRCk1",
      "metadata": {
        "id": "QFDAoNnoRCk1"
      },
      "outputs": [],
      "source": [
        "# Word2Vec baseline: mean word embedding per document, fed to a linear support vector regressor.\n",
        "processed_docs = [simple_preprocess(text) for text in document_texts]\n",
        "word2vec_model = Word2Vec(sentences=processed_docs, vector_size=400, window=5, min_count=1, workers=4)\n",
        "\n",
        "def document_vector(text):\n",
        "    \"\"\"Mean Word2Vec embedding of the words in `text`, or a zero vector if none are known.\"\"\"\n",
        "    words = simple_preprocess(text)\n",
        "    vectors = [word2vec_model.wv[word] for word in words if word in word2vec_model.wv]\n",
        "\n",
        "    if not vectors:\n",
        "        return np.zeros(word2vec_model.vector_size)\n",
        "    return np.mean(vectors, axis=0)\n",
        "\n",
        "w2v_features = np.array([document_vector(text) for text in document_texts])\n",
        "# Fixed random_state so the LinearSVR fit is reproducible between runs.\n",
        "svr_model = LinearSVR(random_state=123)\n",
        "svr_model.fit(w2v_features, price_targets)\n",
        "\n",
        "def w2v_predictor(item):\n",
        "    \"\"\"Predict from the item's mean word embedding, floored at 0.\n",
        "\n",
        "    A price cannot be negative, and a guess of -1 or less would make the tester's\n",
        "    log-error computation raise.\n",
        "    \"\"\"\n",
        "    prediction = float(svr_model.predict([document_vector(item.test_prompt())])[0])\n",
        "    return max(prediction, 0)\n",
        "\n",
        "Tester.test(w2v_predictor, test_items[:250])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "kBVWisusQwDq",
      "metadata": {
        "id": "kBVWisusQwDq"
      },
      "outputs": [],
      "source": [
        "forest_model = RandomForestRegressor(n_estimators=200, random_state=123)\n",
        "forest_model.fit(X_train, y_train)\n",
        "\n",
        "def forest_predictor(item):\n",
        "    \"\"\"Predict a price from the hand-crafted features with the fitted random forest.\"\"\"\n",
        "    row = pd.DataFrame([build_features(item)])[feature_columns]\n",
        "    return float(forest_model.predict(row)[0])\n",
        "\n",
        "Tester(forest_predictor, test_items[:250]).run()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "wgth1KvMSEOb",
      "metadata": {
        "id": "wgth1KvMSEOb"
      },
      "outputs": [],
      "source": [
        "# First 200 training items are used for fine-tuning; the next 50 become the validation file\n",
        "fine_tune_train, fine_tune_validation = train_items[:200], train_items[200:250]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "g7uz8SC5S3_s",
      "metadata": {
        "id": "g7uz8SC5S3_s"
      },
      "outputs": [],
      "source": [
        "def compose_messages(item, include_price=True):\n",
        "    \"\"\"Build the system/user/assistant chat messages for one item.\n",
        "\n",
        "    The user turn is the item's test prompt with the nearest-dollar wording and the\n",
        "    trailing price cue stripped out. The assistant turn carries the price to two\n",
        "    decimals when include_price is True (training records), or only the bare\n",
        "    'Price is $' prefix when it is False (inference requests).\n",
        "    \"\"\"\n",
        "    prompt = item.test_prompt()\n",
        "    for fragment in (\" to the nearest dollar\", \"\\n\\nPrice is $\"):\n",
        "        prompt = prompt.replace(fragment, \"\")\n",
        "\n",
        "    if include_price:\n",
        "        answer = f\"Price is ${item.price:.2f}\"\n",
        "    else:\n",
        "        answer = \"Price is $\"\n",
        "\n",
        "    return [\n",
        "        {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price\"},\n",
        "        {\"role\": \"user\", \"content\": prompt},\n",
        "        {\"role\": \"assistant\", \"content\": answer},\n",
        "    ]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "_zHswJwzWCHZ",
      "metadata": {
        "id": "_zHswJwzWCHZ"
      },
      "outputs": [],
      "source": [
        "def build_jsonl(items):\n",
        "    \"\"\"Return the items as JSON Lines text, one chat-format record per line.\n",
        "\n",
        "    Each record is a JSON object whose messages key holds compose_messages(item).\n",
        "    Records are joined by newlines with no trailing newline; no items gives ''.\n",
        "    \"\"\"\n",
        "    records = (json.dumps({\"messages\": compose_messages(item)}) for item in items)\n",
        "    return \"\\n\".join(records)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "rSHYkQojWH8Q",
      "metadata": {
        "id": "rSHYkQojWH8Q"
      },
      "outputs": [],
      "source": [
        "# Serialise both fine-tuning splits to JSONL files under storage_dir\n",
        "train_jsonl, validation_jsonl = (\n",
        "    storage_dir / \"balanced_pricer_train.jsonl\",\n",
        "    storage_dir / \"balanced_pricer_validation.jsonl\",\n",
        ")\n",
        "train_jsonl.write_text(build_jsonl(fine_tune_train))\n",
        "validation_jsonl.write_text(build_jsonl(fine_tune_validation))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "37BH0u-QWOiY",
      "metadata": {
        "id": "37BH0u-QWOiY"
      },
      "outputs": [],
      "source": [
        "openai_client = OpenAI()\n",
        "\n",
        "# Upload both JSONL splits so the fine-tuning job can reference them by file id\n",
        "with train_jsonl.open(\"rb\") as train_handle:\n",
        "    train_file = openai_client.files.create(file=train_handle, purpose=\"fine-tune\")\n",
        "\n",
        "with validation_jsonl.open(\"rb\") as validation_handle:\n",
        "    validation_file = openai_client.files.create(file=validation_handle, purpose=\"fine-tune\")\n",
        "\n",
        "train_file, validation_file"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2nNSE_AzWYMq",
      "metadata": {
        "id": "2nNSE_AzWYMq"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): the W&B integration presumably needs a W&B account linked to the OpenAI org - confirm\n",
        "wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"balanced-pricer\"}}\n",
        "\n",
        "# Single-epoch fine-tune of gpt-4o-mini on the uploaded splits, with a fixed seed\n",
        "fine_tune_job = openai_client.fine_tuning.jobs.create(\n",
        "    model=\"gpt-4o-mini-2024-07-18\",\n",
        "    training_file=train_file.id,\n",
        "    validation_file=validation_file.id,\n",
        "    hyperparameters={\"n_epochs\": 1},\n",
        "    seed=123,\n",
        "    suffix=\"balanced-pricer\",\n",
        "    integrations=[wandb_integration],\n",
        ")\n",
        "fine_tune_job"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ASiJUw-Fh8Ul",
      "metadata": {
        "id": "ASiJUw-Fh8Ul"
      },
      "outputs": [],
      "source": [
        "# Inspect the job's current state together with up to ten of its events\n",
        "job_status = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id)\n",
        "job_events = openai_client.fine_tuning.jobs.list_events(\n",
        "    fine_tuning_job_id=fine_tune_job.id,\n",
        "    limit=10,\n",
        ")\n",
        "job_status, job_events"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7jB_7gqBiH_r",
      "metadata": {
        "id": "7jB_7gqBiH_r"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "# fine_tuned_model stays null until the job succeeds, so wait for a terminal status\n",
        "# before reading it; otherwise Run All would hand None to every later inference call.\n",
        "POLL_SECONDS = 30\n",
        "finished_job = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id)\n",
        "while finished_job.status not in {\"succeeded\", \"failed\", \"cancelled\"}:\n",
        "    time.sleep(POLL_SECONDS)\n",
        "    finished_job = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id)\n",
        "\n",
        "if finished_job.status != \"succeeded\":\n",
        "    raise RuntimeError(f\"Fine-tuning job {fine_tune_job.id} ended with status {finished_job.status!r}\")\n",
        "\n",
        "fine_tuned_model_name = finished_job.fine_tuned_model\n",
        "print(fine_tuned_model_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "BHfLSadhiVQE",
      "metadata": {
        "id": "BHfLSadhiVQE"
      },
      "outputs": [],
      "source": [
        "def tuned_predictor(item):\n",
        "    \"\"\"Ask the fine-tuned model for an item's price and parse the number from its reply.\n",
        "\n",
        "    The request uses compose_messages(item, include_price=False), so it ends with the\n",
        "    same 'Price is $' assistant prefix the training records start with. max_tokens is\n",
        "    kept small because only a short price string is expected back.\n",
        "    \"\"\"\n",
        "    messages = compose_messages(item, include_price=False)\n",
        "    response = openai_client.chat.completions.create(\n",
        "        model=fine_tuned_model_name,\n",
        "        messages=messages,\n",
        "        seed=123,\n",
        "        max_tokens=7,\n",
        "    )\n",
        "    # message.content is Optional in the OpenAI SDK (e.g. on a refusal); pass a str either way.\n",
        "    # NOTE(review): confirm get_price(\"\") returns a sensible default rather than raising.\n",
        "    answer = response.choices[0].message.content or \"\"\n",
        "    return get_price(answer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "C0CiTZ4jkjrI",
      "metadata": {
        "id": "C0CiTZ4jkjrI"
      },
      "outputs": [],
      "source": [
        "# Spot-check one test item: print its true price first, then the fine-tuned estimate\n",
        "if test_items:\n",
        "    sample_item = test_items[0]\n",
        "    print(sample_item.price)  # ground truth\n",
        "    print(tuned_predictor(sample_item))  # model estimate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "WInQE0ObkuBl",
      "metadata": {
        "id": "WInQE0ObkuBl"
      },
      "outputs": [],
      "source": [
        "# Score the fine-tuned model on the same first-250 test slice used for the baselines\n",
        "evaluation_items = test_items[:250]\n",
        "Tester.test(tuned_predictor, evaluation_items)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.13.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
